{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Qa5r1-7RmS8j"
   },
   "source": [
    "# Attribution Demo\n",
    "\n",
    "<a target=\"_blank\" href=\"https://colab.research.google.com/github/safety-research/circuit-tracer/blob/main/demos/attribute_demo.ipynb\">\n",
    "  <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
    "</a>\n",
    "\n",
    "In this demo, you'll learn how to load models and perform attribution on them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#@title Colab Setup Environment\n",
    "\n",
    "try:\n",
    "    import google.colab\n",
    "    !mkdir -p repository && cd repository && \\\n",
    "     git clone https://github.com/safety-research/circuit-tracer && \\\n",
    "     curl -LsSf https://astral.sh/uv/install.sh | sh && \\\n",
    "     uv pip install -e circuit-tracer/\n",
    "\n",
    "    import sys\n",
    "    from huggingface_hub import notebook_login\n",
    "    sys.path.append('repository/circuit-tracer')\n",
    "    sys.path.append('repository/circuit-tracer/demos')\n",
    "    notebook_login(new_session=False)\n",
    "    IN_COLAB = True\n",
    "except ImportError:\n",
    "    IN_COLAB = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "P8fNhpqzmS8k"
   },
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "import torch\n",
    "\n",
    "from circuit_tracer import ReplacementModel, attribute\n",
    "from circuit_tracer.utils import create_graph_files"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ZN_3kEyfmS8k"
   },
   "source": [
    "First, load your model and transcoders by name. `model_name` is a standard HuggingFace / [TransformerLens](https://github.com/TransformerLensOrg/TransformerLens) model name; we'll use `google/gemma-2-2b`. We set `transcoder_name` to `gemma`, which is shorthand for the [Gemma Scope](https://arxiv.org/abs/2408.05147) transcoders; for each layer, we use the transcoder with the lowest L0 (mean number of active features).\n",
    "\n",
    "We additionally support `model_name = \"meta-llama/Llama-3.2-1B\"`, with `\"llama\"` transcoders; these are ReLU skip-transcoders that we trained, available [here](https://huggingface.co/mntss/skip-transcoder-Llama-3.2-1B-131k-nobos/tree/new-training).\n",
    "\n",
    "If you want to use other models, you'll have to provide your own transcoders. To do this, set `transcoder_name` to point to your own configuration file, specifying the list of transcoders that you want to use. You can see `circuit_tracer/configs` for example configs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "BBsETpl0mS8l"
   },
   "outputs": [],
   "source": [
    "model_name = 'google/gemma-2-2b'\n",
    "transcoder_name = \"gemma\"\n",
    "model = ReplacementModel.from_pretrained(model_name, transcoder_name, dtype=torch.bfloat16)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "dcZNR0egmS8l"
   },
   "source": [
    "Next, set your attribution arguments."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "5XBwNyq4mS8l"
   },
   "outputs": [],
   "source": [
    "prompt = \"The capital of state containing Dallas is\"  # The prompt to build the attribution graph for\n",
    "max_n_logits = 10  # Maximum number of logits to attribute from. We attribute from min(max_n_logits, number of logits needed to reach desired_logit_prob)\n",
    "desired_logit_prob = 0.95  # Attribute from the minimum number of logits needed to reach this probability mass (or max_n_logits, whichever is lower)\n",
    "max_feature_nodes = 8192  # Maximum number of feature nodes to attribute from. Lower is faster, but you will lose more of the graph. None means no limit.\n",
    "batch_size = 256  # Batch size when attributing\n",
    "offload = 'disk' if IN_COLAB else 'cpu'  # Offload various parts of the model during attribution to save memory. Can be 'disk', 'cpu', or None (keep on GPU)\n",
    "verbose = True  # Whether to display a tqdm progress bar and timing report"
   ]
  },
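  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the interaction between `max_n_logits` and `desired_logit_prob` concrete, here is a minimal, self-contained sketch of the selection rule described in the comments above. It is an illustration only, not the library's implementation: `select_logit_count` is a hypothetical helper, and the probabilities are made up."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical helper for illustration; not part of circuit_tracer\n",
    "def select_logit_count(probs, max_n_logits, desired_logit_prob):\n",
    "    # Count how many of the most probable tokens are needed to reach\n",
    "    # desired_logit_prob, stopping early once max_n_logits is hit.\n",
    "    total, count = 0.0, 0\n",
    "    for p in sorted(probs, reverse=True):\n",
    "        if count >= max_n_logits or total >= desired_logit_prob:\n",
    "            break\n",
    "        total += p\n",
    "        count += 1\n",
    "    return count\n",
    "\n",
    "toy_probs = [0.70, 0.20, 0.06, 0.03, 0.01]  # made-up next-token probabilities\n",
    "print(select_logit_count(toy_probs, max_n_logits=10, desired_logit_prob=0.95))  # 3: the top three tokens cover ~0.96\n",
    "print(select_logit_count(toy_probs, max_n_logits=2, desired_logit_prob=0.95))  # 2: capped by max_n_logits"
   ]
  },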
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "VXfD-5GrmS8l"
   },
   "source": [
    "Then, just run attribution!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "wx2XiXVjmS8l"
   },
   "outputs": [],
   "source": [
    "graph = attribute(\n",
    "    prompt=prompt,\n",
    "    model=model,\n",
    "    max_n_logits=max_n_logits,\n",
    "    desired_logit_prob=desired_logit_prob,\n",
    "    batch_size=batch_size,\n",
    "    max_feature_nodes=max_feature_nodes,\n",
    "    offload=offload,\n",
    "    verbose=verbose\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "RUn1YKnUmS8l"
   },
   "source": [
    "We now have a graph object! We can save it as a .pt file, but be warned that it's large (~167MB)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "2tLE4FzdmS8m"
   },
   "outputs": [],
   "source": [
    "graph_dir = 'graphs'\n",
    "graph_name = 'example_graph.pt'\n",
    "graph_dir = Path(graph_dir)\n",
    "graph_dir.mkdir(exist_ok=True)\n",
    "graph_path = graph_dir / graph_name\n",
    "\n",
    "graph.to_pt(graph_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "w3cdLLfJmS8m"
   },
   "source": [
    "Given this object, we can create the graph files needed to visualize the graph. Give the graph a slug (name), and set the node / edge thresholds for pruning. Pruning removes unimportant nodes and edges from your graph; lower thresholds (i.e., more aggressive pruning) result in smaller graphs. These may be easier to interpret, but explain less of the model's behavior."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Vh8HPtimmS8m"
   },
   "outputs": [],
   "source": [
    "slug = \"dallas-austin\"  # this is the name that you assign to the graph\n",
    "graph_file_dir = './graph_files'  # where to write the graph files. no need to make this one; create_graph_files does that for you\n",
    "node_threshold = 0.8  # keep only the minimum # of nodes whose cumulative influence is >= 0.8\n",
    "edge_threshold = 0.98  # keep only the minimum # of edges whose cumulative influence is >= 0.98\n",
    "\n",
    "create_graph_files(\n",
    "    graph_or_path=graph_path,  # the graph to create files for\n",
    "    slug=slug,\n",
    "    output_path=graph_file_dir,\n",
    "    node_threshold=node_threshold,\n",
    "    edge_threshold=edge_threshold\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "EQuFE-eimS8m"
   },
   "source": [
    "Now, you can visualize the graph using the following commands! This will spin up a local server to act as the frontend.\n",
    "\n",
    "**If you're running this notebook on a remote server, make sure that you set up port forwarding, so that the chosen port is accessible on your local machine too.**\n",
    "\n",
    "You can select nodes by clicking on them. Ctrl/Cmd+Click on nodes to pin them to or unpin them from your subgraph. G+Click on nodes in the subgraph to group them together into a supernode; G+Click on the X next to a supernode to dissolve it. Click on the edit button to edit node descriptions, and click on a supernode's description to edit it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "gMZ8Ee-KmS8m"
   },
   "outputs": [],
   "source": [
    "from circuit_tracer.frontend.local_server import serve\n",
    "\n",
    "\n",
    "port = 8046\n",
    "server = serve(data_dir='./graph_files/', port=port)\n",
    "\n",
    "if IN_COLAB:\n",
    "    from google.colab import output as colab_output  # noqa\n",
    "    colab_output.serve_kernel_port_as_iframe(port, path='/index.html', height='800px', cache_in_notebook=True)\n",
    "else:\n",
    "    from IPython.display import IFrame, display\n",
    "    print(f\"Use the IFrame below, or open your graph here: http://localhost:{port}/index.html\")\n",
    "    display(IFrame(src=f'http://localhost:{port}/index.html', width='100%', height='800px'))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "yDGiO8jBmS8m"
   },
   "source": [
    "Once you're done, you can stop the server with the following command."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "185O1Ck1mS8m"
   },
   "outputs": [],
   "source": [
    "# server.stop()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "98579UbGmS8m"
   },
   "source": [
    "Congrats, you're done! Go to `intervention_demo.ipynb` to see how to perform interventions, or check out `gemma_demo.ipynb` and `llama_demo.ipynb` for worked examples. Read on for a bit more info about the Graph class and pruning."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "tkgM1cBCmS8m"
   },
   "source": [
    "## Graphs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "IGnU9l1zmS8m"
   },
   "source": [
    "Earlier, you created a graph object. Its adjacency matrix / edge weights are stored in `graph.adjacency_matrix` in a dense format; rows are target nodes and columns are source nodes. Nodes are ordered by type. The first `len(graph.real_features)` indices are feature nodes; the `i`th one corresponds to the `i`th feature in `graph.real_features`, given in `(layer, position, feature_idx)` format. The next `graph.cfg.n_layers * graph.n_pos` indices are error nodes. The next `graph.n_pos` indices are token nodes. The final `len(graph.logit_tokens)` indices are logit nodes.\n",
    "\n",
    "The value of the cell `graph.adjacency_matrix[target, source]` is the direct effect of the source node on the target node. That is, it tells you how much the target node's value would change if the source node were set to 0, while holding the attention patterns, layernorm denominators, and other feature activations constant. Thus, if the target node is a feature, this tells you how much the target feature would change; if the target node is a logit, this tells you how much the (de-meaned) value of the logit would change.\n",
    "\n",
    "Note that `gemma-2-2b` belongs to a model family that uses logit softcapping. This means that a softcap function, `softcap(x) = t * tanh(x/t)`, is used to constrain the logits to fall within (-t, t); `gemma-2-2b` uses `t=30`. For such models, we predict the change in logits *pre-softcap*, as the nonlinearity introduced by softcapping would cause our attribution to yield only approximate direct effect values."
   ]
  },
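  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The index layout described above can be sketched with plain Python. This is a self-contained illustration, not part of `circuit_tracer`: `node_ranges` is a hypothetical helper, and the sizes passed to it are made-up numbers standing in for `len(graph.real_features)`, `graph.cfg.n_layers`, `graph.n_pos`, and `len(graph.logit_tokens)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical helper for illustration; not part of circuit_tracer\n",
    "def node_ranges(n_features, n_layers, n_pos, n_logits):\n",
    "    # Map each node type to its half-open (start, end) index range,\n",
    "    # following the order: features, error nodes, token nodes, logit nodes.\n",
    "    sizes = [\n",
    "        ('feature', n_features),\n",
    "        ('error', n_layers * n_pos),\n",
    "        ('token', n_pos),\n",
    "        ('logit', n_logits),\n",
    "    ]\n",
    "    ranges, start = {}, 0\n",
    "    for name, size in sizes:\n",
    "        ranges[name] = (start, start + size)\n",
    "        start += size\n",
    "    return ranges\n",
    "\n",
    "# Made-up sizes, for illustration only\n",
    "ranges = node_ranges(n_features=5000, n_layers=26, n_pos=8, n_logits=10)\n",
    "print(ranges)\n",
    "# Since rows are targets and columns are sources, rows in ranges['logit'] and\n",
    "# columns in ranges['feature'] would hold the feature-to-logit direct effects."
   ]
  },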
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "IWTV8i9zmS8n"
   },
   "source": [
    "### Pruning\n",
    "Given a graph, you might want to prune it, as it will otherwise contain many low-impact nodes and edges that clutter the circuit diagram while adding little information. You can prune nodes by absolute influence, i.e. the total impact, direct and indirect, that the nodes have on the logits. The default `node_threshold` is 0.8: this means we keep the minimum number of nodes required to capture 80% of all logit effects. Similarly, the `edge_threshold`, by default 0.98, means that we keep the minimum number of edges required to capture 98% of all logit effects."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "GmKhWpuUmS8n"
   },
   "outputs": [],
   "source": [
    "from circuit_tracer.graph import prune_graph\n",
    "prune_graph(graph, node_threshold=0.7, edge_threshold=0.95)"
   ]
  },
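  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cumulative-threshold idea behind `node_threshold` and `edge_threshold` can be illustrated on toy numbers. The sketch below is a simplified illustration under stated assumptions, not what `prune_graph` does internally: `keep_by_cumulative_influence` is a hypothetical helper, and the influence scores are invented."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical helper for illustration; not part of circuit_tracer\n",
    "def keep_by_cumulative_influence(influence, threshold):\n",
    "    # Return the indices of the smallest set of items whose share of the\n",
    "    # total influence reaches the threshold, taking the largest items first.\n",
    "    order = sorted(range(len(influence)), key=lambda i: influence[i], reverse=True)\n",
    "    total = sum(influence)\n",
    "    kept, running = [], 0.0\n",
    "    for i in order:\n",
    "        if running >= threshold * total:\n",
    "            break\n",
    "        kept.append(i)\n",
    "        running += influence[i]\n",
    "    return sorted(kept)\n",
    "\n",
    "toy_influence = [0.05, 0.40, 0.10, 0.30, 0.15]  # invented per-node influence scores\n",
    "print(keep_by_cumulative_influence(toy_influence, threshold=0.8))  # [1, 3, 4]\n",
    "print(keep_by_cumulative_influence(toy_influence, threshold=0.5))  # [1, 3]: a lower threshold keeps fewer nodes"
   ]
  },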
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "uCo4FSQwqcBl"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "T4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}